# -*- coding: utf-8 -*-
import torch
import torch.nn as nn


class GCBlock(nn.Module):
    """Attention-pooled global context block with a residual connection.

    The block (presumably the "global context" block of GCNet -- confirm
    against the project's references) does three things:

    1. A 1x1 convolution scores every spatial position; a softmax over all
       positions turns the scores into attention weights.
    2. The input features are averaged with those weights, giving one
       context vector per sample of shape ``(n, in_channels, 1, 1)``.
    3. The context vector goes through a bottleneck
       (1x1 conv -> LayerNorm -> ReLU -> 1x1 conv) and is added back to
       every spatial position of the input via broadcasting.

    Args:
        in_channels: Number of channels of the input feature map.
        scale: Channel reduction factor of the bottleneck. The bottleneck
            width is ``in_channels // scale``, clamped to at least 1 so
            that ``in_channels < scale`` does not create a zero-channel
            convolution.
    """

    def __init__(self, in_channels: int, scale: int = 16) -> None:
        super().__init__()
        self.in_channels = in_channels
        # Never let the bottleneck collapse to zero channels.
        self.out_channels = max(1, self.in_channels // scale)
        # One attention logit per spatial position.
        self.conv_key = nn.Conv2d(self.in_channels, 1, 1)
        # Applied to a (n, h*w, 1) tensor, so dim=1 is the spatial axis.
        self.soft_max = nn.Softmax(dim=1)
        self.conv_value = nn.Sequential(
            nn.Conv2d(self.in_channels, self.out_channels, 1),
            nn.LayerNorm([self.out_channels, 1, 1]),
            nn.ReLU(),
            nn.Conv2d(self.out_channels, self.in_channels, 1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add the transformed global context to ``x``.

        Args:
            x: Feature map of shape ``(n, in_channels, h, w)``. It does not
                need to be contiguous.

        Returns:
            A tensor with the same shape as ``x``.
        """
        n, c, h, w = x.size()
        positions = h * w

        # (n, 1, h, w) -> (n, h*w, 1). The channel axis has size 1, so moving
        # it to the end is a pure reshape and needs no permute.
        logits = self.conv_key(x).reshape(n, positions, 1)
        attention = self.soft_max(logits)

        # reshape (not view) so inputs whose spatial dims cannot be merged by
        # stride (e.g. transposed or cropped tensors) are handled too.
        features = x.reshape(n, c, positions)

        # (n, c, h*w) @ (n, h*w, 1) -> (n, c, 1): attention-weighted average.
        context = torch.matmul(features, attention).reshape(n, c, 1, 1)

        # (n, c, 1, 1) broadcasts over the spatial dims of x.
        return x + self.conv_value(context)


